{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02145.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02145.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N_X2yQs42uCG",
        "colab_type": "text"
      },
      "source": [
        "## 6.8 长短期记忆 ( LSTM )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LITcA21729_h",
        "colab_type": "text"
      },
      "source": [
        "### 6.8.2 读取数据"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TVwQUwyw3TMY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np \n",
        "import torch \n",
        "from torch import nn, optim \n",
        "import torch.nn.functional as F \n",
        "import d2l \n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GCv4ThaZ5rXA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!mkdir ../../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gJslTo4L5rHR",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "b082f0db-e30d-4bab-f7c0-4346e5219b0b"
      },
      "source": [
        "!git clone https://github.com/ShusenTang/Dive-into-DL-PyTorch"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'Dive-into-DL-PyTorch'...\n",
            "remote: Enumerating objects: 1692, done.\u001b[K\n",
            "remote: Total 1692 (delta 0), reused 0 (delta 0), pack-reused 1692\n",
            "Receiving objects: 100% (1692/1692), 25.29 MiB | 5.02 MiB/s, done.\n",
            "Resolving deltas: 100% (975/975), done.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nUfGXUUC5q4O",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!cp Dive-into-DL-PyTorch/data/jaychou_lyrics.txt.zip ../../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oclpXfyY6HQD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AniMJiAw3spX",
        "colab_type": "text"
      },
      "source": [
        "### 6.8.3 从零开始实现"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9QaYBvoE3w5e",
        "colab_type": "text"
      },
      "source": [
        "#### 1 初始化模型参数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "chc_Lugb30de",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "7e3e3a7b-4441-4659-9e27-0421c80ac3b7"
      },
      "source": [
        "num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size \n",
        "print('will use', device)\n",
        "\n",
        "def get_params():\n",
        "    def _one(shape):\n",
        "        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)\n",
        "        return torch.nn.Parameter(ts, requires_grad=True)\n",
        "    def _three():\n",
        "        return (_one((num_inputs, num_hiddens)), \n",
        "            _one((num_hiddens, num_hiddens)), \n",
        "            torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))\n",
        "\n",
        "    W_xi, W_hi, b_i = _three()\n",
        "    W_xf, W_hf, b_f = _three()\n",
        "    W_xo, W_ho, b_o = _three()\n",
        "    W_xc, W_hc, b_c = _three()\n",
        "\n",
        "    W_hq = _one((num_hiddens, num_outputs))\n",
        "    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)\n",
        "    return nn.ParameterList([W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q])"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "will use cuda\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ELLDOJrF5f6s",
        "colab_type": "text"
      },
      "source": [
        "#### 2 定义模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7ZS5WQch68Dt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def init_lstm_state(batch_size, num_hiddens, device):\n",
        "    return (torch.zeros((batch_size, num_hiddens), device=device), \n",
        "        torch.zeros((batch_size, num_hiddens), device=device))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "J8J-7NHl7QCn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def lstm(inputs, state, params):\n",
        "    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params \n",
        "    (H, C) = state \n",
        "    outputs = []\n",
        "    for X in inputs:\n",
        "        I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i)\n",
        "        F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f)\n",
        "        O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o)\n",
        "        C_tilda = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)\n",
        "        C = F * C + I * C_tilda \n",
        "        H = O * C.tanh()\n",
        "        Y = torch.matmul(H, W_hq) + b_q\n",
        "        outputs.append(Y)\n",
        "    return outputs, (H, C)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cxdaocYO8sqC",
        "colab_type": "text"
      },
      "source": [
        "#### 3 训练模型并创作歌词"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wkbdKkgn8zy8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2\n",
        "pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "33vUTiwM9FO8",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 218
        },
        "outputId": "5b834be1-a3a0-4342-a2cd-fca3063c5cf8"
      },
      "source": [
        "d2l.train_and_predict_rnn(lstm, get_params, init_lstm_state, num_hiddens, vocab_size, device, corpus_indices, idx_to_char,\n",
        "            char_to_idx, False, num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, \n",
        "            prefixes)"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 40, perplexity 213.371680, time 0.60 sec\n",
            " - 分开 我不的 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 \n",
            " - 不分开 我不的 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 我不 \n",
            "epoch 80, perplexity 65.394913, time 0.60 sec\n",
            " - 分开 我想我这你的你 我想 你想你想你 我想 我想你的爱我 你不 我想我想你的你 爱爱我 你你我 我不要\n",
            " - 不分开 你不我 你不我 我不要 我不要 我不我 你不我 我不要 我不我 你不我 我不要 我不我 你不我 我\n",
            "epoch 120, perplexity 15.610854, time 0.60 sec\n",
            " - 分开 我想你这你 我不要这样 我不 我不 我不能 爱情走的太快就像龙卷风 不不能能我 我不要再想 我不 \n",
            " - 不分开 你这你的生笑 一天  说你的话笑 一通  说你的话笑 快你在对我的你的我 我想揍你已你 这这不这样\n",
            "epoch 160, perplexity 3.942796, time 0.61 sec\n",
            " - 分开 我说带你 我想多难熬  没有你在我有多难熬多烦恼  没有你烦 我有著努恼 你有 我不我的爱笑 我爱\n",
            " - 不分开你的让 我想要你已你很你 想要再你样打我妈妈 我道的你 你甘著听 不要再这样我妈妈 我想你的你甘你 \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d15WpneC9p0i",
        "colab_type": "text"
      },
      "source": [
        "### 6.8.4 简洁实现"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CNHgn7FM9yjv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 218
        },
        "outputId": "b3a8dd95-e3ec-46c3-b9c9-d25968554828"
      },
      "source": [
        "lr = 1e-2 \n",
        "lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)\n",
        "model = d2l.RNNModel(lstm_layer, vocab_size)\n",
        "d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, char_to_idx, \n",
        "                num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, prefixes)"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 40, perplexity 1.019573, time 0.15 sec\n",
            " - 分开始移动 回到当初爱你的时空 停格内容不忠 所有回忆对着我进攻 我的伤口被你拆封 誓言太沉重泪被纵容 \n",
            " - 不分开 我不能 爱情走的太快就像龙卷风 不能承受我已无处可躲 我不要再想 我不要再想 我不 我不 我不要再\n",
            "epoch 80, perplexity 1.012941, time 0.15 sec\n",
            " - 分开始移动 回到当初爱你的时空 停格内容不忠 所有回忆对着我进攻 我的伤口被你拆封 誓言太沉重泪被纵容 \n",
            " - 不分开 我不到 整个人对就怎么小 就怎么每天祈祷我的心跳你知道  杵在伊斯坦堡 却只想你和汉堡 我想要你的\n",
            "epoch 120, perplexity 1.012826, time 0.15 sec\n",
            " - 分开始移动 可爱女人 坏坏的让我疯狂的可爱女人 坏坏的让我疯狂的可爱女人 漂亮的让我面红的可爱女人 温柔\n",
            " - 不分开 我不到 整颗心悬在半空 我只能够远远看著 这些我都做得到 但那个人已经不是我 没有你在 我有多难熬\n",
            "epoch 160, perplexity 1.012320, time 0.15 sec\n",
            " - 分开始移动 回到当初爱你的时空 停格内容不忠 所有回忆对着我进攻 我的伤口被你拆封 誓言太沉重泪被纵容 \n",
            " - 不分开 我不到 爱情我的太快就像龙卷风 不能承受我已无处可躲 我不要再想 我不要再想 我不 我不 我不要再\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}